#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import csv
import math
import random
from bisect import bisect_left
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np

# Defaults (CLI defaults in build_parser; the CSV names are also the loaders' fallbacks)
DEFAULT_SOURCE = "ecb"  # "btc" or "ecb"
DEFAULT_BTC_CSV = "btcusd_1-min_data.csv"  # read by the BTC loader when no path is given
DEFAULT_BTC_COLUMN = "Close"  # BTC price column; the loader matches it case-insensitively
DEFAULT_ECB_CSV = "eurofxref-hist.csv"  # read by the ECB loader when no path is given
DEFAULT_CURRENCIES = ["CHF", "GBP", "JPY", "USD"]  # ECB columns; one results table each
DEFAULT_RUNS = 1000  # sampled predictions per results table
DEFAULT_R = 100.0  # R parameter forwarded to the threshold rules
DEFAULT_ALPHAS = [0.1, 0.5, 0.9]  # CVaR levels; maxcvar_T requires each to lie in [0, 1)
# BTC timestamp window [T0, T1) (T0 inclusive, T1 exclusive). If the timestamp column holds
# Unix epoch seconds, T0 is 2019-12-31 23:00:00 UTC and T1 is 2024-12-31 23:00:00 UTC,
# i.e. midnight of 2020-01-01 and 2025-01-01 at UTC+1.
DEFAULT_T0 = 1577833200.0
DEFAULT_T1 = 1735686000.0

# Globals from data: reassigned by _postprocess_from_values every time a data set is loaded.
exchange_rates: List[float] = []  # running-maximum (non-decreasing) subsequence of the series
lowest_price: float = 0.0  # min(exchange_rates); above_threshold's fallback when T is never hit
p_star: float = 0.0  # largest of the 8 block maxima, which is the series maximum
h: float = 0.0  # largest minus smallest block maximum; spread used to sample predictions
M_data: float = 0.0  # series maximum; passed as M to the threshold rules

# Utils
def _isnum(s: str) -> bool:
    """Return True if ``float(s)`` succeeds, False otherwise.

    Anything ``float`` accepts counts as numeric, including surrounding whitespace and
    the strings ``"nan"`` / ``"inf"``; callers that need finite values check
    ``math.isfinite`` themselves.
    """
    try:
        float(s)
    except (TypeError, ValueError, OverflowError):
        # Only the conversion failures float() is documented to raise; anything else
        # (e.g. KeyboardInterrupt-adjacent or programming errors) should not be hidden.
        return False
    return True

def discretization(a: float, b: float, steps: int = 200) -> List[float]:
    """Return ``steps + 1`` evenly spaced points running from ``a`` to ``b``.

    Point ``k`` is ``a + k * (b - a) / steps``; the last point equals ``b`` up to rounding.
    """
    width = b - a
    points: List[float] = []
    for k in range(steps + 1):
        points.append(a + k * width / steps)
    return points

def linear_weight(z: float) -> float:
    """Triangular kernel on [-1, 1]: 1 at ``z = 0``, falling linearly to 0 at ``|z| = 1``.

    Returns 0 outside [-1, 1].
    """
    if abs(z) > 1.0:
        return 0.0
    return 1.0 - abs(z)

# Truncated Gaussian CDF on [-1,1]
# Standard deviation of the Gaussian kernel exp(-0.5 * (z / _SIGMA)**2) before it is
# truncated to [-1, 1].
_SIGMA = 1.0 / 4.0
def _unnorm_gauss(z: float) -> float:
    """Unnormalised Gaussian kernel ``exp(-z**2 / (2 * _SIGMA**2))``.

    No other function shown in this file calls it; the CDF code works from its
    antiderivative ``_unnorm_gauss_int`` instead.
    """
    scaled = z / _SIGMA
    return math.exp(-0.5 * scaled ** 2)

def _unnorm_gauss_int(z: float) -> float:
    """Antiderivative of ``_unnorm_gauss``, anchored so that it is 0 at ``z = 0``.

    Equals ``_SIGMA * sqrt(pi / 2) * erf(z / (_SIGMA * sqrt(2)))``.
    """
    prefactor = _SIGMA * math.sqrt(math.pi / 2.0)
    scale = _SIGMA * math.sqrt(2.0)
    return prefactor * math.erf(z / scale)

# Unnormalised kernel mass on [-1, 1]; gaussian_integration divides by it so that its
# CDF runs from 0 at z = -1 to 1 at z = +1.
_GNORM = _unnorm_gauss_int(+1.0) - _unnorm_gauss_int(-1.0)

def gaussian_integration(z: float) -> float:
    """CDF at ``z`` of the Gaussian kernel truncated to [-1, 1].

    ``z`` is clamped to [-1, 1] first, so the result is 0 at or below -1 and 1 at or
    above +1.
    """
    if z < -1.0:
        clipped = -1.0
    elif z > 1.0:
        clipped = 1.0
    else:
        clipped = z
    mass_below = _unnorm_gauss_int(clipped) - _unnorm_gauss_int(-1.0)
    return mass_below / _GNORM

# Data loading + 8-block R_y
def _block_maxima(vals: List[float], n_blocks: int = 8) -> List[float]:
    """Return the maximum of each of ``n_blocks`` consecutive slices of non-empty ``vals``.

    Slice ``k`` is ``vals[n*k//n_blocks : n*(k+1)//n_blocks + 1]``. The ``+ 1`` makes each
    slice also contain the first element of the next slice, and keeps every slice
    non-empty even when ``len(vals) < n_blocks``.
    NOTE(review): the one-element overlap may be unintentional; it affects ``h`` — confirm
    before changing it.
    """
    n = len(vals)
    maxima: List[float] = []
    for k in range(n_blocks):
        start = n * k // n_blocks
        stop = n * (k + 1) // n_blocks + 1
        maxima.append(max(vals[start:stop]))
    return maxima


def _running_max_filter(vals: List[float]) -> List[float]:
    """Keep each value that is >= every value before it (the running-maximum records).

    The result is non-decreasing, so it can be searched with ``bisect``. Values that
    compare false against the running maximum (e.g. NaN) are dropped; if nothing survives,
    the first input value is returned on its own.
    """
    records: List[float] = []
    run_max = -math.inf
    for v in vals:
        if v >= run_max:
            records.append(v)
            run_max = v
    return records or [vals[0]]


def _postprocess_from_values(vals: List[float], tag: str) -> None:
    """Derive the data-dependent module globals from a price series.

    Sets the globals:
      * ``M_data``: ``max(vals)``.
      * ``p_star``: largest of the 8 block maxima. This equals ``max(vals)`` because the
        blocks cover every index.
      * ``h``: largest minus smallest block maximum.
      * ``exchange_rates``: the running-maximum records of ``vals``, in order.
      * ``lowest_price``: the smallest of those records.

    Args:
        vals: price observations in the order they are to be replayed.
        tag: label prefixed to the diagnostic lines printed to stdout.

    Raises:
        ValueError: if ``vals`` is empty.
    """
    global exchange_rates, lowest_price, p_star, h, M_data
    if not vals:
        raise ValueError("No numeric data loaded.")
    lo, hi = min(vals), max(vals)
    M_data = hi
    # The ratio is informational only; a zero minimum must not abort the whole run.
    ratio = hi / lo if lo != 0 else float("inf")
    print(f"[{tag}] Loaded {len(vals)} points. lowest={lo} highest={hi} ratio={ratio:.6f}")

    best = _block_maxima(vals, 8)
    h = max(best) - min(best)
    p_star_local = max(best)

    exchange_rates = _running_max_filter(vals)
    lowest = min(exchange_rates)
    highest = max(exchange_rates)
    print(f"[{tag}] After monotone filter: {len(exchange_rates)} points (lowest={lowest}, highest={highest})")

    lowest_price = lowest
    p_star = p_star_local

def init_exchange_rates_btc(csv_path: Optional[str],
                            column_name: str = "Close",
                            window: Tuple[float, float] = (DEFAULT_T0, DEFAULT_T1),
                            timestamp_col_index: int = 0) -> None:
    """Load one price column of a BTC CSV inside a time window and set the data globals.

    Rows are kept when both the timestamp cell and the price cell parse as numbers, the
    timestamp lies in the half-open window ``[t0, t1)``, and the price is finite. Kept
    prices, in file order, are passed to ``_postprocess_from_values``.

    Args:
        csv_path: CSV path; ``None`` selects ``DEFAULT_BTC_CSV``. ``~`` is expanded.
        column_name: header name of the price column, matched case-insensitively after
            stripping whitespace.
        window: ``(t0, t1)`` bounds compared against the timestamp column.
        timestamp_col_index: index of the timestamp column.

    Raises:
        FileNotFoundError: if the CSV does not exist.
        ValueError: if the CSV has no header row, or no rows survive the filters.
        KeyError: if ``column_name`` is not in the header.
    """
    if csv_path is None:
        csv_path = DEFAULT_BTC_CSV
    p = Path(csv_path).expanduser()
    if not p.exists():
        raise FileNotFoundError(f"BTC CSV not found: {p}")

    t0, t1 = window
    vals: List[float] = []
    # utf-8-sig strips a leading byte-order mark, so the first header cell compares cleanly.
    with p.open("r", newline="", encoding="utf-8-sig") as f:
        r = csv.reader(f)
        header = next(r, None)
        if header is None:
            raise ValueError(f"BTC CSV is empty: {p}")
        header_norm = [c.strip() for c in header]
        header_lower = [c.lower() for c in header_norm]
        target = column_name.strip().lower()
        if target not in header_lower:
            raise KeyError(f"Column '{column_name}' not found in header {header_norm}")
        i = header_lower.index(target)
        need = max(i, timestamp_col_index)

        # Filter while streaming instead of materialising every row of the file first.
        for row in r:
            if len(row) <= need:
                continue
            ts_cell, val_cell = row[timestamp_col_index], row[i]
            if not (_isnum(ts_cell) and _isnum(val_cell)):
                continue
            if t0 <= float(ts_cell) < t1:
                v = float(val_cell)
                if math.isfinite(v):
                    vals.append(v)

    _postprocess_from_values(vals, tag="BTC")

def init_exchange_rates_ecb(csv_path: Optional[str], currency: str,
                            chronological: bool = True) -> None:
    """Load one currency column of an ECB reference-rate CSV and set the data globals.

    Rows whose ``currency`` cell is missing, non-numeric (e.g. ``N/A``) or non-finite are
    skipped. The remaining values go to ``_postprocess_from_values``, which replays them in
    order, so the order matters for the running-maximum oracle.

    Ordering:
        With ``chronological=True`` (the default), a ``Date`` column present (matched
        case-insensitively), and every kept row's date parsing as ISO ``YYYY-MM-DD``, the
        values are sorted oldest-first (stable). Otherwise file order is kept.

        Sorting is a no-op for a file that is already ascending.
        NOTE(review): the ECB history file is, to my knowledge, published newest-first —
        confirm against the file in use. Pass ``chronological=False`` for plain file order.

    Args:
        csv_path: CSV path; ``None`` selects ``DEFAULT_ECB_CSV``. ``~`` is expanded.
        currency: exact (case-sensitive, whitespace-stripped) header name of the column.
        chronological: sort the rows by their ``Date`` column before processing.

    Raises:
        FileNotFoundError: if the CSV does not exist.
        ValueError: if the CSV has no header row, or no numeric values were found.
        KeyError: if ``currency`` is not in the header.
    """
    from datetime import date  # only this loader parses dates

    if csv_path is None:
        csv_path = DEFAULT_ECB_CSV
    p = Path(csv_path).expanduser()
    if not p.exists():
        raise FileNotFoundError(f"ECB CSV not found: {p}")

    vals: List[float] = []
    date_strs: List[str] = []
    # utf-8-sig strips a leading byte-order mark, so the first header cell ("Date") compares cleanly.
    with p.open("r", newline="", encoding="utf-8-sig") as f:
        r = csv.reader(f)
        header = next(r, None)
        if header is None:
            raise ValueError(f"ECB CSV is empty: {p}")
        header_norm = [c.strip() for c in header]
        if currency not in header_norm:
            raise KeyError(f"Currency '{currency}' not found in header {header_norm}")
        i = header_norm.index(currency)
        header_lower = [c.lower() for c in header_norm]
        date_idx = header_lower.index("date") if "date" in header_lower else None

        for row in r:
            if len(row) <= i or not _isnum(row[i]):
                continue
            v = float(row[i])
            if not math.isfinite(v):
                continue
            vals.append(v)
            has_date = date_idx is not None and date_idx < len(row)
            date_strs.append(row[date_idx].strip() if has_date else "")

    if chronological and date_idx is not None:
        try:
            days = [date.fromisoformat(s) for s in date_strs]
        except ValueError:
            days = None  # some dates are missing or not ISO formatted: keep file order
        if days is not None:
            order = sorted(range(len(vals)), key=days.__getitem__)
            vals = [vals[k] for k in order]

    _postprocess_from_values(vals, tag=f"ECB:{currency}")

# Oracle on real monotone series
def above_threshold(T: float) -> float:
    """Return the first recorded price ``>= T``, or ``lowest_price`` when none reaches ``T``.

    ``exchange_rates`` holds the running-maximum records in replay order and is therefore
    non-decreasing. Binary search finds its smallest element ``>= T``, which is also the
    earliest one.
    """
    pos = bisect_left(exchange_rates, T)
    return exchange_rates[pos] if pos < len(exchange_rates) else lowest_price

def ratio_real(T: float) -> float:
    """Return ``p_star`` divided by the price obtained with reservation threshold ``T``.

    The obtained price comes from ``above_threshold``. When it is not strictly positive,
    the ratio is reported as ``inf``.
    """
    got = above_threshold(T)
    if got > 0:
        return p_star / got
    return float("inf")

def profit_real(T: float) -> float:
    """Return ``T`` if the series ever reaches ``T``, otherwise ``1.0``.

    NOTE(review): unlike ``ratio_real``, this scores the threshold itself rather than the
    price ``above_threshold`` returns. It also uses ``1.0`` rather than ``lowest_price`` on
    a miss, presumably matching a model normalised so that the minimum price is 1 — confirm.
    """
    reached = above_threshold(T) >= T
    if reached:
        return T
    return 1.0

# Robust objectives (R, M_data)
def ratio_offline(p_star_val: float, T: float) -> float:
    """Offline ratio for true maximum ``p_star_val`` and threshold ``T``.

    Returns ``p_star_val / T`` when the threshold is reached (``p_star_val >= T``).
    Otherwise returns ``p_star_val`` itself, i.e. the ratio against a price of 1.
    """
    if p_star_val >= T:
        return p_star_val / T
    return p_star_val

def weighted_diff_robust(p_hat_val: float, p_star_val: float, h_: float, T: float, w, R: float, M: float) -> float:
    """Weighted excess of the offline ratio over a piecewise reference level.

    The reference is ``p_star_val / R`` if ``p_star_val >= R``, ``p_star_val`` if
    ``p_star_val <= M / R``, and ``1.0`` otherwise. The excess
    ``ratio_offline(p_star_val, T) - reference`` is multiplied by the weight
    ``w(|p_star_val - p_hat_val| / h_)``.
    """
    z = abs(p_star_val - p_hat_val) / h_
    offline = ratio_offline(p_star_val, T)
    reference = (p_star_val / R if p_star_val >= R
                 else p_star_val if p_star_val <= M / R
                 else 1.0)
    excess = offline - reference
    return excess * w(z)

def max_pstar_robust(p_hat_val: float, h_: float, T: float, w, R: float, M: float):
    """Worst case of ``weighted_diff_robust`` over a grid of true maxima.

    The grid is ``discretization(p_hat_val - h_, p_hat_val + h_)``.

    Returns:
        ``(value, p_star)`` for the first grid point attaining the largest value.
    """
    scored = (
        (weighted_diff_robust(p_hat_val, ps, h_, T, w, R, M), ps)
        for ps in discretization(p_hat_val - h_, p_hat_val + h_)
    )
    return max(scored, key=lambda pair: pair[0])

def sum_pstar_robust(p_hat_val: float, h_: float, T: float, w, R: float, M: float) -> float:
    """Aggregate ``weighted_diff_robust`` over a grid of true maxima, divided by ``2 * h_``.

    The grid is ``discretization(p_hat_val - h_, p_hat_val + h_)``.
    NOTE(review): dividing the plain sum by ``2 * h_`` (rather than multiplying by the grid
    spacing) gives a positive rescaling of a Riemann average. ``minsum_T_robust`` only
    compares values at a fixed ``h_``, so its argmin is unaffected.
    """
    lower, upper = p_hat_val - h_, p_hat_val + h_
    acc = 0.0
    for p_true in discretization(lower, upper):
        # Plain left-to-right accumulation (sum() would use compensated summation on 3.12+).
        acc = acc + weighted_diff_robust(p_hat_val, p_true, h_, T, w, R, M)
    return acc / (2.0 * h_)

def minmax_T_robust(p_hat_val: float, h_: float, w, R: float, M: float):
    """Choose the threshold that minimises the worst-case weighted regret.

    Every ``T`` on ``discretization(p_hat_val - h_, p_hat_val + h_)`` is scored with
    ``max_pstar_robust``.

    Returns:
        ``((worst_value, worst_p_star), T)`` for the first ``T`` whose worst value is
        smallest.
    """
    best = None
    for T in discretization(p_hat_val - h_, p_hat_val + h_):
        worst = max_pstar_robust(p_hat_val, h_, T, w, R, M)
        if best is None or worst[0] < best[0][0]:
            best = (worst, T)
    return best

def minsum_T_robust(p_hat_val: float, h_: float, w, R: float, M: float):
    """Choose the threshold that minimises the aggregated weighted regret.

    Every ``T`` on ``discretization(p_hat_val - h_, p_hat_val + h_)`` is scored with
    ``sum_pstar_robust``.

    Returns:
        ``(value, T)`` for the first ``T`` with the smallest value.
    """
    scored = (
        (sum_pstar_robust(p_hat_val, h_, T, w, R, M), T)
        for T in discretization(p_hat_val - h_, p_hat_val + h_)
    )
    return min(scored, key=lambda pair: pair[0])

def maxcvar_T(p_hat_val: float, h_: float, alpha: float, R: float, M: float):
    """Pick the threshold maximising the CVaR_alpha objective.

    Candidates:
        ``T = p_hat_val - h_``, scored as ``p_hat_val - h_``; plus a 201-point grid on
        ``[max(p_hat_val - h_, M / R), min(p_hat_val + h_, R)]`` (skipped when empty).

    Grid score:
        ``(T * (1 - alpha - q) + q) / (1 - alpha)``, where ``q`` is
        ``gaussian_integration`` of ``T`` clamped to the prediction interval and rescaled
        by ``h_``.

    Returns:
        ``(score, T)`` of the first candidate with the largest score.

    Raises:
        ValueError: unless ``0 <= alpha < 1``.
    """
    if not 0 <= alpha < 1:
        raise ValueError("alpha must be in [0,1)")
    lower = p_hat_val - h_
    upper = p_hat_val + h_
    keep = 1 - alpha
    grid_lo = max(lower, M / R)
    grid_hi = min(upper, R)

    best_score, best_T = lower, lower
    if grid_lo <= grid_hi:
        for T in discretization(grid_lo, grid_hi):
            z = (min(max(T, lower), upper) - p_hat_val) / h_
            q = gaussian_integration(z)
            score = (T * (keep - q) + q) / keep
            if score > best_score:
                best_score, best_T = score, T
    return best_score, best_T

# PO baselines
def r_of_lambda(lmbda: float, M: float) -> float:
    """Return the ``+`` root ``r`` of ``lmbda * r**2 + (1 - lmbda) * r - M = 0``.

    At ``lmbda == 0`` the equation degenerates to ``r = M``, which is also the limit of the
    root as ``lmbda -> 0``. That value is returned directly instead of dividing by zero.
    """
    if lmbda == 0:
        return float(M)
    return (math.sqrt((1 - lmbda) ** 2 + 4 * lmbda * M) - (1 - lmbda)) / (2 * lmbda)


def lambda_from_r(r_target: float, M: float) -> float:
    """Find ``lmbda`` in ``(1e-8, 1 - 1e-8)`` with ``r_of_lambda(lmbda, M) == r_target``.

    Monotonicity:
        ``r_of_lambda(., M)`` is monotone in ``lmbda``. It falls from ``M`` towards
        ``sqrt(M)`` when ``M > 1`` and rises when ``M < 1``.

    Method:
        The bisection direction is read from the endpoint values, so both cases converge
        to the root.

    Out-of-range targets:
        A target outside the attainable range converges to the end of the interval whose
        value is closest to it.
    """
    lo, hi = 1e-8, 1 - 1e-8
    decreasing = r_of_lambda(lo, M) > r_of_lambda(hi, M)
    for _ in range(80):
        mid = 0.5 * (lo + hi)
        above = r_of_lambda(mid, M) > r_target
        # Decreasing curve: value still above target -> root lies to the right.
        # Increasing curve: value above target -> root lies to the left.
        if above == decreasing:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

def pareto_alg(lam: float, R: float, y: float, M: float) -> float:
    """PO1 threshold: ``lam * R`` plus ``(1 - lam)`` times the prediction ``y`` rescaled by ``R / M``."""
    anchored = lam * R
    from_prediction = (1 - lam) * y * R / M
    return anchored + from_prediction

# Prediction sampler
def sample_prediction_from_data() -> float:
    """Draw a synthetic prediction ``p_star + h * x`` of the loaded series' maximum.

    ``x`` follows N(0, 0.5**2) conditioned on [-1, 1], via rejection sampling with
    ``random.gauss``, so the result lies within ``p_star ± h``. When ``h <= 0`` the
    prediction is exact and no random number is consumed.
    """
    if h <= 0:
        return p_star
    draw = random.gauss(0.0, 0.5)
    while not -1.0 <= draw <= 1.0:
        draw = random.gauss(0.0, 0.5)
    return p_star + h * draw

# Bootstrap CI
def mean_ci_pm(samples: List[float], conf: float = 0.95, n_boot: int = 2000, rng: Optional[np.random.Generator] = None):
    """Return ``(mean, halfwidth)`` for ``samples`` from a percentile bootstrap.

    Resampling:
        ``n_boot`` resamples (with replacement, same size as ``samples``) are drawn from
        ``rng``, or from a fresh ``np.random.default_rng(42)`` when ``rng`` is omitted.

    Half-width:
        The larger distance from the sample mean to the ``(1 - conf) / 2`` and
        ``1 - (1 - conf) / 2`` quantiles of the resample means. ``mean ± halfwidth``
        therefore covers that percentile interval.

    Small inputs:
        Fewer than two samples give a zero half-width, and a NaN mean when empty.
    """
    data = np.asarray(samples, dtype=float)
    size = data.size
    if size <= 1:
        return (float(np.mean(data)) if size else float("nan")), 0.0
    gen = np.random.default_rng(42) if rng is None else rng
    point = float(np.mean(data))
    picks = gen.integers(0, size, size=(n_boot, size))
    boot_means = data[picks].mean(axis=1)
    tail = (1.0 - conf) / 2.0
    q_low, q_high = np.quantile(boot_means, [tail, 1.0 - tail])
    return point, max(point - float(q_low), float(q_high) - point)

# Experiment
def run_experiment_table(runs: int, R: float, alphas: List[float], rng_seed: int = 42):
    """Score every threshold policy on the loaded series over ``runs`` sampled predictions.

    Each run draws a prediction ``y`` with ``sample_prediction_from_data`` (clamped to at
    least ``1e-12``). Every policy turns ``y`` into one reservation threshold, which is
    scored with ``ratio_real`` and ``profit_real``.

    Args:
        runs: number of sampled predictions.
        R: forwarded to ``lambda_from_r``, ``pareto_alg`` and the robust/CVaR rules.
        alphas: CVaR levels; each adds a ``"CVaR_<alpha>"`` policy. A level whose key
            repeats an earlier one is evaluated only once.
        rng_seed: seeds ``random``, NumPy's legacy global RNG and the bootstrap generator.

    Returns:
        ``(ratios_summary, profits_summary)``, each mapping policy name to
        ``(mean, ci_halfwidth)`` from ``mean_ci_pm``.
    """
    random.seed(rng_seed)
    np.random.seed(rng_seed)  # legacy global NumPy RNG; unused here, kept seeded for callers
    rng = np.random.default_rng(rng_seed)

    M = M_data
    lam = lambda_from_r(R, M)
    # Loop-invariant: h is only reassigned when a data set is (re)loaded.
    h_loc = max(h, 1e-12)

    # Key the CVaR policies by column name. A repeated alpha would otherwise be recorded
    # twice per run, silently shrinking its bootstrap interval.
    cvar_levels: Dict[str, float] = {}
    for a in alphas:
        cvar_levels.setdefault(f"CVaR_{a}", a)

    algs = ["δT", "PO1", "PO2", "MAX", "AVG"] + list(cvar_levels)
    ratio_runs: Dict[str, List[float]] = {k: [] for k in algs}
    profit_runs: Dict[str, List[float]] = {k: [] for k in algs}

    for _ in range(runs):
        y = max(sample_prediction_from_data(), 1e-12)
        thresholds: Dict[str, float] = {
            "δT": y - h_loc,
            "PO1": pareto_alg(lam, R, y, M),
            "PO2": y,
            "MAX": minmax_T_robust(y, h_loc, linear_weight, R, M)[1],
            "AVG": minsum_T_robust(y, h_loc, linear_weight, R, M)[1],
        }
        for name, a in cvar_levels.items():
            thresholds[name] = maxcvar_T(y, h_loc, a, R, M)[1]

        for name, T in thresholds.items():
            ratio_runs[name].append(ratio_real(T))
            profit_runs[name].append(profit_real(T))

    ratios_summary: Dict[str, Tuple[float, float]] = {}
    profits_summary: Dict[str, Tuple[float, float]] = {}
    for k in algs:
        # Same call order as before: the shared rng makes the bootstrap draws order-dependent.
        ratios_summary[k] = mean_ci_pm(ratio_runs[k], rng=rng)
        profits_summary[k] = mean_ci_pm(profit_runs[k], rng=rng)
    return ratios_summary, profits_summary

def print_table_block(title: str,
                      ratios: Dict[str, Tuple[float, float]],
                      profits: Dict[str, Tuple[float, float]],
                      alg_order: List[str]) -> None:
    """Print one LaTeX table block.

    Output:
        A ``%`` comment carrying ``title``, followed by three rows: algorithm names, mean
        ratios and mean profits.

    Row format:
        Each row is a label cell followed by one ``&``-separated cell per entry of
        ``alg_order``, ended by a LaTeX line break. Statistic cells read
        ``mean ± halfwidth`` with six decimals.
    """
    def row(label: str, cells: List[str]) -> str:
        # The label is its own column, so it needs a separator like every other cell.
        return " & ".join([label, *cells]) + r" \\"

    def stat_cells(stats: Dict[str, Tuple[float, float]]) -> List[str]:
        return [f"{stats[a][0]:.6f} ± {stats[a][1]:.6f}" for a in alg_order]

    print(f"\n% ---- {title} ----")
    print(row("Algorithm", list(alg_order)))
    print(row("Avg Ratio", stat_cells(ratios)))
    print(row("Avg Profit", stat_cells(profits)))

# Main
def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the experiment driver.

    Option scope:
        ``--csv`` falls back to the selected source's default file.
        ``--column``, ``--t0`` and ``--t1`` are read only by the BTC loader.
        ``--currencies`` is read only by the ECB loop in ``main``.
    """
    p = argparse.ArgumentParser(description="Real-data One-Max with 8-block R_y; averages + 95% CI (±).", add_help=True)
    p.add_argument("--source", choices=["btc", "ecb"], default=DEFAULT_SOURCE,
                   help="data set to load (default: %(default)s)")
    p.add_argument("--csv", default=None,
                   help="input CSV path (default: the selected source's default file name)")
    p.add_argument("--column", default=DEFAULT_BTC_COLUMN,
                   help="BTC price column, case-insensitive; btc only (default: %(default)s)")
    p.add_argument("--currencies", nargs="+", default=DEFAULT_CURRENCIES,
                   help="ECB currency columns, one table each; ecb only (default: %(default)s)")
    p.add_argument("--runs", type=int, default=DEFAULT_RUNS,
                   help="sampled predictions per table (default: %(default)s)")
    p.add_argument("--R", type=float, default=DEFAULT_R,
                   help="R parameter passed to the threshold rules (default: %(default)s)")
    p.add_argument("--t0", type=float, default=DEFAULT_T0,
                   help="start of the BTC timestamp window, inclusive; btc only (default: %(default)s)")
    p.add_argument("--t1", type=float, default=DEFAULT_T1,
                   help="end of the BTC timestamp window, exclusive; btc only (default: %(default)s)")
    p.add_argument("--alphas", type=float, nargs="+", default=DEFAULT_ALPHAS,
                   help="CVaR levels, each in [0, 1) (default: %(default)s)")
    return p

def main():
    """Command-line entry point: load the chosen data set(s) and print one table per set.

    For ``--source btc`` one table is printed. For ``--source ecb`` one table is printed
    per entry of ``--currencies``, reloading the globals before each.
    """
    args = build_parser().parse_args()
    alg_order = ["δT", "PO1", "PO2", "MAX", "AVG"] + [f"CVaR_{a}" for a in args.alphas]

    def report(title):
        # Runs the experiment on whatever data set is currently loaded into the globals.
        ratios, profits = run_experiment_table(runs=args.runs, R=args.R, alphas=args.alphas)
        print_table_block(title, ratios, profits, alg_order)

    if args.source == "btc":
        init_exchange_rates_btc(csv_path=args.csv, column_name=args.column, window=(args.t0, args.t1))
        print(f"Using R={args.R:.4f}, M (from data)={M_data:.4f}, p_star={p_star:.4f}, h={h:.4f}")
        report("BTC (averages over runs)")
        return

    for cur in args.currencies:
        init_exchange_rates_ecb(csv_path=args.csv, currency=cur)
        print(f"[{cur}] Using R={args.R:.4f}, M (from data)={M_data:.4f}, p_star={p_star:.4f}, h={h:.4f}")
        report(f"ECB {cur} (averages over runs)")

# Script entry point: run the experiment only when executed directly, not when imported.
if __name__ == "__main__":
    main()
